PyTorch Lightning教程二:验证、测试、checkpoint、早停策略

您所在的位置:网站首页 pytorch 暂停训练 PyTorch Lightning教程二:验证、测试、checkpoint、早停策略

PyTorch Lightning教程二:验证、测试、checkpoint、早停策略

2023-11-15 09:20| 来源: 网络整理| 查看: 265

介绍:上一期介绍了如何利用PyTorch Lightning搭建并训练一个模型(仅使用训练集),为了保证模型可以泛化到未见过的数据上,数据集通常被分为训练和测试两个集合,测试集与训练集相互独立,用以测试模型的泛化能力。本期通过增加验证和测试集来达到该目的,同时,还引入checkpoint和早停策略,以得到模型最佳权重。

相关链接:https://lightning.ai/docs/pytorch/stable/levels/basic_level_2.html

训练集、验证集、测试集的使用 1.添加依赖,获取训练集和测试集

添加相应的依赖,同时使用MNIST数据集,获取训练和测试集

import torch.utils.data as data from torchvision import datasets import torchvision.transforms as transforms from torch.utils.data import DataLoader # 加载数据(测试集,train=False) transform = transforms.ToTensor() train_set = datasets.MNIST(root="MNIST", download=True, train=True, transform=transform) test_set = datasets.MNIST(root="MNIST", download=True, train=False, transform=transform) 2.实现并调用test_step

在定义LightningModule中,实现test_step方法;在外部,调用test方法

class LitAutoEncoder(pl.LightningModule): def training_step(self, batch, batch_idx): ... def test_step(self, batch, batch_idx): # 测试,该方法与training_step相似 x, y = batch x = x.view(x.size(0), -1) z = self.encoder(x) x_hat = self.decoder(z) test_loss = F.mse_loss(x_hat, x) self.log("test_loss", test_loss) # 初始化Trainer trainer = Trainer() # 执行test方法 trainer.test(model, dataloaders=DataLoader(test_set)) 3.实现并调用验证集

通常使用torch.utils.data中的方法,将训练集中的一部分数据化为验证集

# 训练集中的20%数据划为验证集 train_set_size = int(len(train_set) * 0.8) valid_set_size = len(train_set) - train_set_size # 拆分,使用data.random_split方法 seed = torch.Generator().manual_seed(42) train_set, valid_set = data.random_split(train_set, [train_set_size, valid_set_size], generator=seed)

与测试集一样,需要在定义LightningModule中,实现validation_step方法;在外部,调用fit方法

class LitAutoEncoder(pl.LightningModule): def training_step(self, batch, batch_idx): ... def validation_step(self, batch, batch_idx): x, y = batch x = x.view(x.size(0), -1) z = self.encoder(x) x_hat = self.decoder(z) val_loss = F.mse_loss(x_hat, x) self.log("val_loss", val_loss) def test_step(self, batch, batch_idx): ... # 调用torch.utils.data中的DataLoader对训练和测试集进行封装 train_loader = DataLoader(train_set) valid_loader = DataLoader(valid_set) # 在fit方法中,引入valid_loader,即验证集 trainer = Trainer() trainer.fit(model, train_loader, valid_loader) checkpoint

checkpoint有两个作用,一是能得到每一次epoch后的模型权重,能得到最佳表现的权重;二是能够在中断或停止后,继续在当前checkpoint处,继续训练。在Lightning中的checkpoint,包含模型的整个内部状态,这与普通的PyTorch不同,即使在最复杂的分布式训练环境中,Lightning也可以保存恢复模型所需的一切。包含以下状态:

16-bit scaling factor (若使用16精度训练)Current epochGlobal stepLightningModule’s state_dictState of all optimizersState of all learning rate schedulersState of all callbacks (for stateful callbacks)State of datamodule (for stateful datamodules)The hyperparameters (init arguments) with which the model was createdThe hyperparameters (init arguments) with which the datamodule was createdState of Loops 保存与调用方法 # 保存方法,可自定义default_root_dir路径,若不设置路径,将会自动保存 trainer = Trainer(default_root_dir="some/path/") # 调用方法 model = MyLightningModule.load_from_checkpoint("/path/to/checkpoint.ckpt") model.eval() # disable randomness, dropout, etc... y_hat = model(x)

调用,还可以使用torch的方法

checkpoint = torch.load(checkpoint, map_location=lambda storage, loc: storage) print(checkpoint["hyper_parameters"]) # {"learning_rate": the_value, "another_parameter": the_other_value}

也可以实现重现,例如模型LitModel(in_dim=32, out_dim=10)

# 使用 in_dim=32, out_dim=10 model = LitModel.load_from_checkpoint(PATH) # 使用 in_dim=128, out_dim=10 model = LitModel.load_from_checkpoint(PATH, in_dim=128, out_dim=10)

Lightning和PyTorch完全兼容

checkpoint = torch.load(CKPT_PATH) encoder_weights = checkpoint["encoder"] decoder_weights = checkpoint["decoder"]

设置checkpoint不可见

trainer = Trainer(enable_checkpointing=False)

如果想全部重新恢复

model = LitModel() trainer = Trainer()

自动恢复所有相关参数 model, epoch, step, LR schedulers, etc…

trainer.fit(model, ckpt_path="some/path/to/my_checkpoint.ckpt") 早停策略 EarlyStopping Callback

在Lightning中,早停回调步骤如下:

Import EarlyStopping callback. 载入EarlyStopping回调方法Log the metric you want to monitor using log() method. 加载日志方法Init the callback, and set monitor to the logged metric of your choice. 设置monitorSet the mode based on the metric needs to be monitored. 设置modePass the EarlyStopping callback to the Trainer callbacks flag. 调入EarlyStropping from lightning.pytorch.callbacks.early_stopping import EarlyStopping class LitModel(LightningModule): def validation_step(self, batch, batch_idx): loss = ... self.log("val_loss", loss) model = LitModel() trainer = Trainer(callbacks=[EarlyStopping(monitor="val_loss", mode="min")]) trainer.fit(model) # 也可以使用自定义的EarlyStopping策略 early_stop_callback = EarlyStopping(monitor="val_accuracy", min_delta=0.00, patience=3, verbose=False, mode="max") trainer = Trainer(callbacks=[early_stop_callback]) # EarlyStopping的文档链接https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.EarlyStopping.html#lightning.pytorch.callbacks.EarlyStopping 注意 EarlyStopping默认在一次Validation后调用,但是Validation可以自定义多少次epoch后进行一次验证,例如check_val_every_n_epoch and val_check_interval。 完整代码 # coding:utf-8 import torch import torch.nn as nn import torch.utils.data as data from torchvision import datasets import torchvision.transforms as transforms from torch.utils.data import DataLoader import torch.nn.functional as F import lightning as L # -------------------------------- # Step 1: 定义模型 # -------------------------------- class LitAutoEncoder(L.LightningModule): def __init__(self): super().__init__() self.encoder = nn.Sequential(nn.Linear(28 * 28, 128), nn.ReLU(), nn.Linear(128, 3)) self.decoder = nn.Sequential(nn.Linear(3, 128), nn.ReLU(), nn.Linear(128, 28 * 28)) def training_step(self, batch, batch_idx): x, y = batch x = x.view(x.size(0), -1) z = self.encoder(x) x_hat = self.decoder(z) loss = F.mse_loss(x_hat, x) self.log("train_loss", loss) return loss def test_step(self, batch, batch_idx): # 测试,该方法与training_step相似 x, y = batch x = x.view(x.size(0), -1) z = self.encoder(x) x_hat = self.decoder(z) test_loss = F.mse_loss(x_hat, x) self.log("test_loss", test_loss) def validation_step(self, batch, batch_idx): x, y = batch x = x.view(x.size(0), -1) z = self.encoder(x) x_hat = self.decoder(z) val_loss = F.mse_loss(x_hat, x) self.log("val_loss", val_loss) def configure_optimizers(self): optimizer = torch.optim.Adam(self.parameters(), lr=1e-3) return optimizer def forward(self, x): # forward 定义了一次 预测/推理 行为 embedding = self.encoder(x) return embedding # -------------------------------- # Step 2: 加载数据+模型 # -------------------------------- transform = transforms.ToTensor() train_set = datasets.MNIST(root="MNIST", download=True, train=True, transform=transform) test_set = datasets.MNIST(root="MNIST", download=True, train=False, transform=transform) # 训练集中的20%数据划为验证集 train_set_size = int(len(train_set) * 0.8) valid_set_size = len(train_set) - train_set_size # 拆分,使用data.random_split方法 seed = torch.Generator().manual_seed(42) train_set, valid_set = data.random_split(train_set, [train_set_size, valid_set_size], generator=seed) train_loader = DataLoader(train_set) valid_loader = DataLoader(valid_set) autoencoder = LitAutoEncoder() # -------------------------------- # Step 3: 训练+验证+测试 # -------------------------------- # 训练+验证 trainer = L.Trainer(default_root_dir="some/path/") # 这里自定义需要保存的路径 trainer.fit(autoencoder, train_loader, valid_loader) # 测试 trainer.test(autoencoder, dataloaders=DataLoader(test_set))


【本文地址】


今日新闻


推荐新闻


    CopyRight 2018-2019 办公设备维修网 版权所有 豫ICP备15022753号-3